{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 机器学习练习 机器学习实践\n",
    "代码更新地址：https://github.com/fengdu78/WZU-machine-learning-course\n",
    "\n",
    "整理编译：黄海广 haiguang2000@wzu.edu.cn,光城\n",
    "\n",
    "最后更新：2023-7-24"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在本节教程中将会绘制几个图形，于是我们激活matplotlib,使得在notebook中显示内联图。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")#忽略警告"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 为什么要出这个教程？"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`scikit-learn` 提供最先进的机器学习算法。 但是，这些算法不能直接用于原始数据。 原始数据需要事先进行预处理。 因此，除了机器学习算法之外，scikit-learn还提供了一套预处理方法。此外，`scikit-learn` 提供用于流水线化这些估计器的连接器(即转换器，回归器，分类器，聚类器等)。\n",
    "\n",
    "在本教程中,将介绍`scikit-learn` 函数集，允许流水线估计器、评估这些流水线、使用超参数优化调整这些流水线以及创建复杂的预处理步骤。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.基本用例：训练和测试分类器"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于第一个示例，我们将在数据集上训练和测试一个分类器。 我们将使用此示例来回忆`scikit-learn`的API。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将使用`digits`数据集，这是一个手写数字的数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_digits\n",
    "\n",
    "X, y = load_digits(return_X_y=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1797, 64)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`X`中的每行包含64个图像像素的强度。 对于`X`中的每个样本，我们得到表示所写数字对应的`y`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The digit in the image is 0\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAGMklEQVR4nO3bsU0DWRhGUXtFA9OCKcG0AiVACVCCe3EJUAJuwSXgEmazq9WKAD3JelicE0/wBeO5+gNv13VdNwCw2Wz+mT0AgN9DFACIKAAQUQAgogBARAGAiAIAEQUAcvfTB7fb7TV38D+Pj4+zJww7HA6zJwz5+PiYPWHI29vb7AlDLpfL7Al/zk/+q+xSACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFAHI3ewDfOxwOsycM2+12sycMWZZl9oQhX19fsycMeXp6mj1h2PF4nD3halwKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQO5mD7i2/X4/e8KQ3W43e8Kw+/v72ROGnM/n2ROGvL+/z54w5FZ/m5vNZnM8HmdPuBqXAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACB3swdc27IssycMOZ1OsycMO5/Psyf8Kbf8rvD7uBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGA3M0ecG3LssyeMOTj42P2BG7Erb7jl8tl9gS+4VIAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAcjd7wLVdLpfZE4bs9/vZE/6cZVlmTxhyq+/K8XicPYFvuBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAbNd1XX/04HZ77S1XsdvtZk8Y8vn5OXvCsJeXl9kThjw+Ps6eMORW3/GHh4fZE/6cn3zuXQoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAg
ogBARAGAbNd1XX/04HZ77S38x/Pz8+wJw15fX2dPGHI6nWZPGPL09DR7AjfiJ597lwIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQ7bqu6+wRAPwOLgUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFAPIvRrFVA6H5bgEAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(X[0].reshape(8, 8), cmap='gray');# 下面完成灰度图的绘制\n",
    "# 灰度显示图像\n",
    "plt.axis('off')# 关闭坐标轴\n",
    "\n",
    "print('The digit in the image is {}'.format(y[0]))# 格式化打印"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在机器学习中，我们应该通过在不同的数据集上进行训练和测试来评估我们的模型。`train_test_split` 是一个用于将数据拆分为两个独立数据集的效用函数。`stratify`参数可强制将训练和测试数据集的类分布与整个数据集的类分布相同。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, ..., 8, 9, 8])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X,\n",
    "                                                    y,\n",
    "                                                    stratify=y,\n",
    "                                                    test_size=0.25,\n",
    "                                                    random_state=42)\n",
    "# 划分数据为训练集与测试集,添加stratify参数，以使得训练和测试数据集的类分布与整个数据集的类分布相同。"
   ]
  },
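  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of what `stratify` does, the sketch below reloads the data and compares the per-class proportions of the full dataset with those of the two subsets. It is only an illustration: the names `full`, `train`, and `test` and the use of `np.bincount` are choices made here, not part of the original tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X, y = load_digits(return_X_y=True)\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, stratify=y, test_size=0.25, random_state=42)\n",
    "\n",
    "# Fraction of samples that fall in each of the 10 classes\n",
    "full = np.bincount(y) / len(y)\n",
    "train = np.bincount(y_train) / len(y_train)\n",
    "test = np.bincount(y_test) / len(y_test)\n",
    "\n",
    "# With stratify=y the three rows should be nearly identical\n",
    "print(np.round(full, 3))\n",
    "print(np.round(train, 3))\n",
    "print(np.round(test, 3))"
   ]
  },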
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一旦我们拥有独立的培训和测试集，我们就可以使用`fit`方法学习机器学习模型。 我们将使用`score`方法来测试此方法，依赖于默认的准确度指标。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the LogisticRegression is 0.9622\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression  # 求出Logistic回归的精确度得分\n",
    "\n",
    "clf = LogisticRegression(solver='lbfgs',\n",
    "                         multi_class='ovr',\n",
    "                         max_iter=5000,\n",
    "                         random_state=42)\n",
    "# 创建逻辑回归模型clf，设置相关参数\n",
    "clf.fit(X_train, y_train)\n",
    "# 使用训练集X_train和y_train对模型进行训练\n",
    "accuracy = clf.score(X_test, y_test)\n",
    "# 使用测试集X_test和y_test计算模型的准确率\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__,\n",
    "                                                  accuracy))\n",
    "# 输出模型的准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ?clf.score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`scikit-learn`的API在分类器中是一致的。因此，我们可以通过`RandomForestClassifier`轻松替换`LogisticRegression`分类器。这些更改很小，仅与分类器实例的创建有关。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the RandomForestClassifier is 0.9711\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "# RandomForestClassifier轻松替换LogisticRegression分类器\n",
    "clf = RandomForestClassifier(n_estimators=1000, n_jobs=-1, random_state=42)\n",
    "clf.fit(X_train, y_train)\n",
    "accuracy = clf.score(X_test, y_test)\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__, accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the XGBClassifier is 0.9578\n"
     ]
    }
   ],
   "source": [
    "from xgboost import XGBClassifier\n",
    "clf = XGBClassifier(n_estimators=1000)\n",
    "clf.fit(X_train, y_train)\n",
    "accuracy = clf.score(X_test, y_test)\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__, accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the GradientBoostingClassifier is 0.9556\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "clf = GradientBoostingClassifier(n_estimators=100, random_state=0)\n",
    "clf.fit(X_train, y_train)\n",
    "accuracy = clf.score(X_test, y_test)\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__,\n",
    "                                                  accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the GradientBoostingClassifier is 0.9573\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import balanced_accuracy_score\n",
    "y_pred = clf.predict(X_test)\n",
    "accuracy = balanced_accuracy_score(y_pred, y_test)\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__,\n",
    "                                                  accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the SVC is 0.9911\n"
     ]
    }
   ],
   "source": [
    "from sklearn.svm import SVC, LinearSVC\n",
    "\n",
    "clf = SVC()\n",
    "clf.fit(X_train, y_train)\n",
    "accuracy = clf.score(X_test, y_test)\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__,\n",
    "                                                  accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the LinearSVC is 0.9467\n"
     ]
    }
   ],
   "source": [
    "clf = LinearSVC()\n",
    "clf.fit(X_train, y_train)\n",
    "accuracy = clf.score(X_test, y_test)\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__,\n",
    "                                                  accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 标准化您的数据\n",
    "在学习模型之前可能需要预处理。例如，一个用户可能对创建手工制作的特征或者算法感兴趣，那么他可能会对数据进行一些先验假设。\n",
    "\n",
    "在我们的例子中，线性模型使用的求解器期望数据被规范化。因此，我们需要在训练模型之前标准化数据。为了观察这个必要条件，我们将检查训练模型所需的迭代次数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MinMaxScaler`变换器用于归一化数据，`StandardScaler`用于标准化数据。该标量应该以下列方式应用：学习（即，`fit`方法）训练集上的统计数据并标准化（即，`transform`方法）训练集和测试集。 最后，我们将训练和测试这个模型并得到归一化后的数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler  # 导入MinMaxScaler和StandardScaler类\n",
    "scaler = MinMaxScaler()\n",
    "# 创建MinMaxScaler对象scaler，用于特征缩放\n",
    "X_train_scaled = scaler.fit_transform(X_train)\n",
    "# 使用训练集X_train对特征进行缩放，并将结果保存在变量X_train_scaled中\n",
    "X_test_scaled = scaler.transform(X_test)\n",
    "# 使用训练集X_test对特征进行缩放，并将结果保存在变量X_test_scaled中"
   ]
  },
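  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The text above suggests looking at the number of solver iterations to see why scaling matters. The sketch below is one possible way to do that, not the tutorial's own code: it fits a `LogisticRegression` once on the raw pixels and once on standardized pixels and prints the fitted `n_iter_` attribute of each. The names `X_tr`, `demo_scaler`, `raw_model`, and `std_model` are hypothetical and chosen so that they do not overwrite variables used elsewhere in the notebook. The scaled fit typically needs fewer iterations, but the exact counts depend on the installed `scikit-learn` version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_digits\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "X, y = load_digits(return_X_y=True)\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(\n",
    "    X, y, stratify=y, test_size=0.25, random_state=42)\n",
    "\n",
    "# Learn the scaling statistics on the training data only\n",
    "demo_scaler = StandardScaler().fit(X_tr)\n",
    "X_tr_std = demo_scaler.transform(X_tr)\n",
    "\n",
    "raw_model = LogisticRegression(max_iter=5000).fit(X_tr, y_tr)\n",
    "std_model = LogisticRegression(max_iter=5000).fit(X_tr_std, y_tr)\n",
    "\n",
    "# n_iter_ holds the number of solver iterations that were actually run\n",
    "print('raw data:   ', raw_model.n_iter_)\n",
    "print('scaled data:', std_model.n_iter_)"
   ]
  },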
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the LinearSVC is 0.9689\n"
     ]
    }
   ],
   "source": [
    "clf = LinearSVC()\n",
    "clf.fit(X_train_scaled, y_train)\n",
    "accuracy = clf.score(X_test_scaled, y_test)\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__,\n",
    "                                                  accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the LinearSVC is 0.9511\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler,StandardScaler\n",
    "scaler = StandardScaler()\n",
    "X_train_scaled = scaler.fit_transform(X_train)\n",
    "X_test_scaled = scaler.transform(X_test)\n",
    "clf = LinearSVC()\n",
    "clf.fit(X_train_scaled, y_train)\n",
    "accuracy = clf.score(X_test_scaled, y_test)\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__,\n",
    "                                                  accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix, classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = clf.predict(X_test_scaled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[45  0  0  0  0  0  0  0  0  0]\n",
      " [ 0 41  0  0  0  0  1  0  2  0]\n",
      " [ 0  0 44  0  0  0  0  0  0  0]\n",
      " [ 0  1  0 44  0  0  0  0  0  0]\n",
      " [ 0  1  0  0 45  0  0  1  0  1]\n",
      " [ 0  0  0  1  0 44  0  0  1  1]\n",
      " [ 0  1  0  0  0  0 43  0  2  0]\n",
      " [ 0  0  0  1  0  0  0 44  0  2]\n",
      " [ 0  2  0  0  0  0  1  0 38  1]\n",
      " [ 0  0  0  0  0  2  0  0  0 40]]\n"
     ]
    }
   ],
   "source": [
    "print(confusion_matrix(y_pred, y_test))#打印混淆矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>45</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>41</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>44</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>44</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>45</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>44</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>43</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>44</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    0   1   2   3   4   5   6   7   8   9\n",
       "0  45   0   0   0   0   0   0   0   0   0\n",
       "1   0  41   0   0   0   0   1   0   2   0\n",
       "2   0   0  44   0   0   0   0   0   0   0\n",
       "3   0   1   0  44   0   0   0   0   0   0\n",
       "4   0   1   0   0  45   0   0   1   0   1\n",
       "5   0   0   0   1   0  44   0   0   1   1\n",
       "6   0   1   0   0   0   0  43   0   2   0\n",
       "7   0   0   0   1   0   0   0  44   0   2\n",
       "8   0   2   0   0   0   0   1   0  38   1\n",
       "9   0   0   0   0   0   2   0   0   0  40"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(\n",
    "    (confusion_matrix(y_pred, y_test)),\n",
    "    columns=range(10),\n",
    "    index=range(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00        45\n",
      "           1       0.89      0.93      0.91        44\n",
      "           2       1.00      1.00      1.00        44\n",
      "           3       0.96      0.98      0.97        45\n",
      "           4       1.00      0.94      0.97        48\n",
      "           5       0.96      0.94      0.95        47\n",
      "           6       0.96      0.93      0.95        46\n",
      "           7       0.98      0.94      0.96        47\n",
      "           8       0.88      0.90      0.89        42\n",
      "           9       0.89      0.95      0.92        42\n",
      "\n",
      "    accuracy                           0.95       450\n",
      "   macro avg       0.95      0.95      0.95       450\n",
      "weighted avg       0.95      0.95      0.95       450\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_pred, y_test))#打印分类报告"
   ]
  },
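  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells above pass `y_pred` as the first argument, whereas the documented order is `confusion_matrix(y_true, y_pred)`. The toy example below, with labels invented purely for illustration, shows what the order changes: swapping the two arguments transposes the matrix, so rows and columns trade meaning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# Toy labels, invented for this illustration\n",
    "y_true_demo = [0, 0, 0, 1, 1]\n",
    "y_pred_demo = [0, 1, 1, 1, 1]\n",
    "\n",
    "# Documented order: rows are true labels, columns are predicted labels\n",
    "cm = confusion_matrix(y_true_demo, y_pred_demo)\n",
    "# Swapped order: rows are predicted labels, columns are true labels\n",
    "cm_swapped = confusion_matrix(y_pred_demo, y_true_demo)\n",
    "\n",
    "print(cm)\n",
    "print(cm_swapped)\n",
    "# Swapping the arguments transposes the matrix\n",
    "assert np.array_equal(cm.T, cm_swapped)"
   ]
  },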
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 交叉验证\n",
    "\n",
    "分割数据对于评估统计模型性能是必要的。 但是，它减少了可用于学习模型的样本数量。 因此，应尽可能使用交叉验证。有多个拆分也会提供有关模型稳定性的信息。\n",
    "\n",
    "`scikit-learn`提供了三个函数：`cross_val_score`，`cross_val_predict`和`cross_validate`。 后者提供了有关拟合时间，训练和测试分数的更多信息。 我也可以一次返回多个分数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import cross_validate  # 导入cross_validate函数\n",
    "\n",
    "clf = LogisticRegression(solver='lbfgs',\n",
    "                         multi_class='auto',\n",
    "                         max_iter=1000,\n",
    "                         random_state=42)\n",
    "# 创建逻辑回归模型clf，设置相关参数\n",
    "scores = cross_validate(clf,\n",
    "                        X_train_scaled,\n",
    "                        y_train,\n",
    "                        cv=3,\n",
    "                        return_train_score=True)\n",
    "# 使用交叉验证评估模型clf，并将结果保存在变量scores中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C': 1.0,\n",
       " 'class_weight': None,\n",
       " 'dual': False,\n",
       " 'fit_intercept': True,\n",
       " 'intercept_scaling': 1,\n",
       " 'l1_ratio': None,\n",
       " 'max_iter': 1000,\n",
       " 'multi_class': 'auto',\n",
       " 'n_jobs': None,\n",
       " 'penalty': 'l2',\n",
       " 'random_state': 42,\n",
       " 'solver': 'lbfgs',\n",
       " 'tol': 0.0001,\n",
       " 'verbose': 0,\n",
       " 'warm_start': False}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.get_params()\n",
    "# 获取模型的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fit_time</th>\n",
       "      <th>score_time</th>\n",
       "      <th>test_score</th>\n",
       "      <th>train_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.453863</td>\n",
       "      <td>0.001406</td>\n",
       "      <td>0.975501</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.188371</td>\n",
       "      <td>0.001142</td>\n",
       "      <td>0.957684</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.538949</td>\n",
       "      <td>0.001326</td>\n",
       "      <td>0.962138</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   fit_time  score_time  test_score  train_score\n",
       "0  0.453863    0.001406    0.975501          1.0\n",
       "1  0.188371    0.001142    0.957684          1.0\n",
       "2  0.538949    0.001326    0.962138          1.0"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df_scores = pd.DataFrame(scores)\n",
    "df_scores"
   ]
  },
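  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the remark that `cross_validate` can return several scores at once, here is a minimal sketch that passes a list of scorer names through the `scoring` parameter. The choice of the two metrics and the names `X_demo`, `y_demo`, `demo_clf`, and `multi_scores` are assumptions made for this example; the model is fitted on the raw pixels to keep the cell short."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import cross_validate\n",
    "\n",
    "X_demo, y_demo = load_digits(return_X_y=True)\n",
    "demo_clf = LogisticRegression(max_iter=5000)\n",
    "\n",
    "# A list of scorer names produces one test_<name> entry per metric\n",
    "multi_scores = cross_validate(demo_clf, X_demo, y_demo, cv=3,\n",
    "                              scoring=['accuracy', 'balanced_accuracy'])\n",
    "\n",
    "# Columns: fit_time, score_time, test_accuracy, test_balanced_accuracy\n",
    "pd.DataFrame(multi_scores)"
   ]
  },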
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 网格搜索调参\n",
    "可以通过穷举搜索来优化超参数。`GridSearchCV`提供此类实用程序，并通过参数网格进行交叉验证的网格搜索。\n",
    "\n",
    "如下例子，我们希望优化`LogisticRegression`分类器的`C`和`penalty`参数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=3,\n",
       "             estimator=LogisticRegression(max_iter=5000, random_state=42,\n",
       "                                          solver=&#x27;saga&#x27;),\n",
       "             n_jobs=-1,\n",
       "             param_grid=[{&#x27;C&#x27;: [0.01, 0.1, 1, 10], &#x27;penalty&#x27;: [&#x27;l2&#x27;, &#x27;l1&#x27;]}],\n",
       "             return_train_score=True)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=3,\n",
       "             estimator=LogisticRegression(max_iter=5000, random_state=42,\n",
       "                                          solver=&#x27;saga&#x27;),\n",
       "             n_jobs=-1,\n",
       "             param_grid=[{&#x27;C&#x27;: [0.01, 0.1, 1, 10], &#x27;penalty&#x27;: [&#x27;l2&#x27;, &#x27;l1&#x27;]}],\n",
       "             return_train_score=True)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(max_iter=5000, random_state=42, solver=&#x27;saga&#x27;)</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(max_iter=5000, random_state=42, solver=&#x27;saga&#x27;)</pre></div></div></div></div></div></div></div></div></div></div>"
      ],
      "text/plain": [
       "GridSearchCV(cv=3,\n",
       "             estimator=LogisticRegression(max_iter=5000, random_state=42,\n",
       "                                          solver='saga'),\n",
       "             n_jobs=-1,\n",
       "             param_grid=[{'C': [0.01, 0.1, 1, 10], 'penalty': ['l2', 'l1']}],\n",
       "             return_train_score=True)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "# Logistic regression with the 'saga' solver, which supports both the\n",
    "# l1 and l2 penalties; max_iter is raised to 5000 to give it room to converge.\n",
    "clf = LogisticRegression(solver='saga',\n",
    "                         multi_class='auto',\n",
    "                         random_state=42,\n",
    "                         max_iter=5000)\n",
    "# Candidate hyperparameters: a list holding one dictionary with two keys,\n",
    "# the inverse regularization strength C and the penalty type. The keys are\n",
    "# plain parameter names because clf is a bare estimator, not a pipeline.\n",
    "tuned_parameters = [{\n",
    "    'C': [0.01, 0.1, 1, 10],\n",
    "    'penalty': ['l2', 'l1'],\n",
    "}]\n",
    "# 3-fold cross-validated grid search: n_jobs=-1 uses all available\n",
    "# processors and return_train_score=True also records the training scores.\n",
    "grid = GridSearchCV(clf,\n",
    "                    tuned_parameters,\n",
    "                    cv=3,\n",
    "                    n_jobs=-1,\n",
    "                    return_train_score=True)\n",
    "# Run the search on the scaled training set X_train_scaled with labels y_train.\n",
    "grid.fit(X_train_scaled, y_train)"
   ]
  },
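  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use `get_params()` to inspect all parameters of the grid search object; the parameters of the wrapped estimator carry the `estimator__` prefix."
   ]
  },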
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'cv': 3,\n",
       " 'error_score': nan,\n",
       " 'estimator__C': 1.0,\n",
       " 'estimator__class_weight': None,\n",
       " 'estimator__dual': False,\n",
       " 'estimator__fit_intercept': True,\n",
       " 'estimator__intercept_scaling': 1,\n",
       " 'estimator__l1_ratio': None,\n",
       " 'estimator__max_iter': 5000,\n",
       " 'estimator__multi_class': 'auto',\n",
       " 'estimator__n_jobs': None,\n",
       " 'estimator__penalty': 'l2',\n",
       " 'estimator__random_state': 42,\n",
       " 'estimator__solver': 'saga',\n",
       " 'estimator__tol': 0.0001,\n",
       " 'estimator__verbose': 0,\n",
       " 'estimator__warm_start': False,\n",
       " 'estimator': LogisticRegression(max_iter=5000, random_state=42, solver='saga'),\n",
       " 'n_jobs': -1,\n",
       " 'param_grid': [{'C': [0.01, 0.1, 1, 10], 'penalty': ['l2', 'l1']}],\n",
       " 'pre_dispatch': '2*n_jobs',\n",
       " 'refit': True,\n",
       " 'return_train_score': True,\n",
       " 'scoring': None,\n",
       " 'verbose': 0}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid.get_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mean_fit_time</th>\n",
       "      <th>std_fit_time</th>\n",
       "      <th>mean_score_time</th>\n",
       "      <th>std_score_time</th>\n",
       "      <th>param_C</th>\n",
       "      <th>param_penalty</th>\n",
       "      <th>params</th>\n",
       "      <th>split0_test_score</th>\n",
       "      <th>split1_test_score</th>\n",
       "      <th>split2_test_score</th>\n",
       "      <th>mean_test_score</th>\n",
       "      <th>std_test_score</th>\n",
       "      <th>rank_test_score</th>\n",
       "      <th>split0_train_score</th>\n",
       "      <th>split1_train_score</th>\n",
       "      <th>split2_train_score</th>\n",
       "      <th>mean_train_score</th>\n",
       "      <th>std_train_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.682172</td>\n",
       "      <td>0.032064</td>\n",
       "      <td>0.008335</td>\n",
       "      <td>0.000471</td>\n",
       "      <td>0.01</td>\n",
       "      <td>l2</td>\n",
       "      <td>{'C': 0.01, 'penalty': 'l2'}</td>\n",
       "      <td>0.953229</td>\n",
       "      <td>0.930958</td>\n",
       "      <td>0.935412</td>\n",
       "      <td>0.939866</td>\n",
       "      <td>0.009622</td>\n",
       "      <td>7</td>\n",
       "      <td>0.954343</td>\n",
       "      <td>0.953229</td>\n",
       "      <td>0.952116</td>\n",
       "      <td>0.953229</td>\n",
       "      <td>0.000909</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.924438</td>\n",
       "      <td>0.257020</td>\n",
       "      <td>0.007666</td>\n",
       "      <td>0.000474</td>\n",
       "      <td>0.01</td>\n",
       "      <td>l1</td>\n",
       "      <td>{'C': 0.01, 'penalty': 'l1'}</td>\n",
       "      <td>0.543430</td>\n",
       "      <td>0.563474</td>\n",
       "      <td>0.514477</td>\n",
       "      <td>0.540460</td>\n",
       "      <td>0.020113</td>\n",
       "      <td>8</td>\n",
       "      <td>0.550111</td>\n",
       "      <td>0.536748</td>\n",
       "      <td>0.532294</td>\n",
       "      <td>0.539718</td>\n",
       "      <td>0.007571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.901998</td>\n",
       "      <td>0.220747</td>\n",
       "      <td>0.008333</td>\n",
       "      <td>0.000945</td>\n",
       "      <td>0.1</td>\n",
       "      <td>l2</td>\n",
       "      <td>{'C': 0.1, 'penalty': 'l2'}</td>\n",
       "      <td>0.977728</td>\n",
       "      <td>0.953229</td>\n",
       "      <td>0.962138</td>\n",
       "      <td>0.964365</td>\n",
       "      <td>0.010125</td>\n",
       "      <td>4</td>\n",
       "      <td>0.986637</td>\n",
       "      <td>0.989978</td>\n",
       "      <td>0.987751</td>\n",
       "      <td>0.988122</td>\n",
       "      <td>0.001389</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5.114192</td>\n",
       "      <td>0.610863</td>\n",
       "      <td>0.007060</td>\n",
       "      <td>0.001575</td>\n",
       "      <td>0.1</td>\n",
       "      <td>l1</td>\n",
       "      <td>{'C': 0.1, 'penalty': 'l1'}</td>\n",
       "      <td>0.953229</td>\n",
       "      <td>0.935412</td>\n",
       "      <td>0.942094</td>\n",
       "      <td>0.943578</td>\n",
       "      <td>0.007349</td>\n",
       "      <td>6</td>\n",
       "      <td>0.951002</td>\n",
       "      <td>0.961024</td>\n",
       "      <td>0.957684</td>\n",
       "      <td>0.956570</td>\n",
       "      <td>0.004167</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.702494</td>\n",
       "      <td>0.558653</td>\n",
       "      <td>0.007889</td>\n",
       "      <td>0.000685</td>\n",
       "      <td>1</td>\n",
       "      <td>l2</td>\n",
       "      <td>{'C': 1, 'penalty': 'l2'}</td>\n",
       "      <td>0.977728</td>\n",
       "      <td>0.957684</td>\n",
       "      <td>0.962138</td>\n",
       "      <td>0.965850</td>\n",
       "      <td>0.008594</td>\n",
       "      <td>1</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>17.438481</td>\n",
       "      <td>1.551526</td>\n",
       "      <td>0.003374</td>\n",
       "      <td>0.003357</td>\n",
       "      <td>1</td>\n",
       "      <td>l1</td>\n",
       "      <td>{'C': 1, 'penalty': 'l1'}</td>\n",
       "      <td>0.977728</td>\n",
       "      <td>0.959911</td>\n",
       "      <td>0.953229</td>\n",
       "      <td>0.963623</td>\n",
       "      <td>0.010340</td>\n",
       "      <td>5</td>\n",
       "      <td>0.998886</td>\n",
       "      <td>0.998886</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.999258</td>\n",
       "      <td>0.000525</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>9.172883</td>\n",
       "      <td>0.664091</td>\n",
       "      <td>0.001333</td>\n",
       "      <td>0.000472</td>\n",
       "      <td>10</td>\n",
       "      <td>l2</td>\n",
       "      <td>{'C': 10, 'penalty': 'l2'}</td>\n",
       "      <td>0.975501</td>\n",
       "      <td>0.966592</td>\n",
       "      <td>0.955457</td>\n",
       "      <td>0.965850</td>\n",
       "      <td>0.008200</td>\n",
       "      <td>1</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>20.035510</td>\n",
       "      <td>0.830658</td>\n",
       "      <td>0.001415</td>\n",
       "      <td>0.000430</td>\n",
       "      <td>10</td>\n",
       "      <td>l1</td>\n",
       "      <td>{'C': 10, 'penalty': 'l1'}</td>\n",
       "      <td>0.979955</td>\n",
       "      <td>0.966592</td>\n",
       "      <td>0.951002</td>\n",
       "      <td>0.965850</td>\n",
       "      <td>0.011832</td>\n",
       "      <td>1</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   mean_fit_time  std_fit_time  mean_score_time  std_score_time param_C  \\\n",
       "0       0.682172      0.032064         0.008335        0.000471    0.01   \n",
       "1       0.924438      0.257020         0.007666        0.000474    0.01   \n",
       "2       1.901998      0.220747         0.008333        0.000945     0.1   \n",
       "3       5.114192      0.610863         0.007060        0.001575     0.1   \n",
       "4       3.702494      0.558653         0.007889        0.000685       1   \n",
       "5      17.438481      1.551526         0.003374        0.003357       1   \n",
       "6       9.172883      0.664091         0.001333        0.000472      10   \n",
       "7      20.035510      0.830658         0.001415        0.000430      10   \n",
       "\n",
       "  param_penalty                        params  split0_test_score  \\\n",
       "0            l2  {'C': 0.01, 'penalty': 'l2'}           0.953229   \n",
       "1            l1  {'C': 0.01, 'penalty': 'l1'}           0.543430   \n",
       "2            l2   {'C': 0.1, 'penalty': 'l2'}           0.977728   \n",
       "3            l1   {'C': 0.1, 'penalty': 'l1'}           0.953229   \n",
       "4            l2     {'C': 1, 'penalty': 'l2'}           0.977728   \n",
       "5            l1     {'C': 1, 'penalty': 'l1'}           0.977728   \n",
       "6            l2    {'C': 10, 'penalty': 'l2'}           0.975501   \n",
       "7            l1    {'C': 10, 'penalty': 'l1'}           0.979955   \n",
       "\n",
       "   split1_test_score  split2_test_score  mean_test_score  std_test_score  \\\n",
       "0           0.930958           0.935412         0.939866        0.009622   \n",
       "1           0.563474           0.514477         0.540460        0.020113   \n",
       "2           0.953229           0.962138         0.964365        0.010125   \n",
       "3           0.935412           0.942094         0.943578        0.007349   \n",
       "4           0.957684           0.962138         0.965850        0.008594   \n",
       "5           0.959911           0.953229         0.963623        0.010340   \n",
       "6           0.966592           0.955457         0.965850        0.008200   \n",
       "7           0.966592           0.951002         0.965850        0.011832   \n",
       "\n",
       "   rank_test_score  split0_train_score  split1_train_score  \\\n",
       "0                7            0.954343            0.953229   \n",
       "1                8            0.550111            0.536748   \n",
       "2                4            0.986637            0.989978   \n",
       "3                6            0.951002            0.961024   \n",
       "4                1            1.000000            1.000000   \n",
       "5                5            0.998886            0.998886   \n",
       "6                1            1.000000            1.000000   \n",
       "7                1            1.000000            1.000000   \n",
       "\n",
       "   split2_train_score  mean_train_score  std_train_score  \n",
       "0            0.952116          0.953229         0.000909  \n",
       "1            0.532294          0.539718         0.007571  \n",
       "2            0.987751          0.988122         0.001389  \n",
       "3            0.957684          0.956570         0.004167  \n",
       "4            1.000000          1.000000         0.000000  \n",
       "5            1.000000          0.999258         0.000525  \n",
       "6            1.000000          1.000000         0.000000  \n",
       "7            1.000000          1.000000         0.000000  "
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_grid = pd.DataFrame(grid.cv_results_)\n",
    "df_grid"
   ]
  },
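  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The table above can also be summarized programmatically. The following is a minimal, self-contained sketch rather than part of the original run: it builds a separate, smaller grid search on `digits` so that it does not overwrite `grid`. The names `demo_grid`, `Xd_train`, `demo_results` and the reduced grid over `C` are assumptions made for illustration. It shows how `best_index_`, `best_params_` and `best_score_` relate to the rows of `cv_results_`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import GridSearchCV, train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "# Illustrative data split; the variable names are hypothetical\n",
    "Xd, yd = load_digits(return_X_y=True)\n",
    "Xd_train, Xd_test, yd_train, yd_test = train_test_split(\n",
    "    Xd, yd, stratify=yd, random_state=42)\n",
    "Xd_train_scaled = MinMaxScaler().fit_transform(Xd_train)\n",
    "\n",
    "# A deliberately small grid so that the cell runs quickly\n",
    "demo_grid = GridSearchCV(LogisticRegression(max_iter=5000),\n",
    "                         {'C': [0.1, 1, 10]},\n",
    "                         cv=3)\n",
    "demo_grid.fit(Xd_train_scaled, yd_train)\n",
    "\n",
    "# best_index_ is the row of cv_results_ with rank_test_score == 1\n",
    "demo_results = pd.DataFrame(demo_grid.cv_results_)\n",
    "best_row = demo_results.loc[demo_grid.best_index_]\n",
    "print(f'best params: {demo_grid.best_params_}')\n",
    "print(f'best mean CV score: {demo_grid.best_score_:.3f}')"
   ]
  },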
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pipeline operations"
   ]
  },
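  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`scikit-learn` provides the `Pipeline` object, which chains several transformers and a final classifier (or regressor) in sequence. We can create such a pipeline as follows:"
   ]
  },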
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: >"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAq+klEQVR4nO3de1zVdYL/8Tc3OQcIHRVREEWwQq0gMBP9ZbkpJE1rbhenaVekjUknpzVmcsU1r2s8mknSh5l2Ga1htx1n0pymzCC2m2VeUFt9KJr3wgve8igIHuDz+6P17JAgd/kIr+fj4cPO93zOh8/X4xdefb9fxMsYYwQAAGAx79ZeAAAAQF0IFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADW823tBTSXqqoqHTlyRNddd528vLxaezkAAKAejDE6d+6cwsLC5O1d+3mUNhMsR44cUURERGsvAwAANMK3336rnj171vp8mwmW6667TtIPOxwcHNzKq0FLc7vdys3NVVJSkvz8/Fp7OQCaEcd3++JyuRQREeH5Ol6bNhMsly4DBQcHEyztgNvtVkBAgIKDg/mEBrQxHN/tU123c3DTLQAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALBeg4Pls88+03333aewsDB5eXlp9erVdb7mk08+UXx8vPz9/dW3b1+98cYbl41ZvHixIiMj5XA4dPvtt2vjxo0NXRoAAGijGhwsJSUlio2N1eLFi+s1/sCBA7r33ns1fPhwbdu2TZMnT9bjjz+uDz/80DNmxYoVysjI0MyZM7VlyxbFxsYqOTlZxcXFDV0eAABogxr8s4RGjRqlUaNG1Xv80qVL1adPH82fP1+S1K9fP61bt04vvviikpOTJUnZ2dlKT09XWlqa5zXvv/++li1bpqlTpzZ0iQAAoI1p8R9+uH79eo0YMaLatuTkZE2ePFmSdPHiRRUUFCgzM9PzvLe3t0aMGKH169fXOm95ebnKy8s9j10ul6QffmiW2+1uxj1Ac7lQcUFffrtLZRerah1TVl6mo98drnOuyooq7d37jb6pOC0f3yufKOzRs5cc/o5an3d08NaQiH5y+jrr/LgAWt6lz+F8Lm8f6vs+t3iwHDt2TKGhodW2hYaGyuVy6cKFCzpz5owqKytrHFNYWFjrvFlZWZo9e/Zl23NzcxUQENA8i0ez2lFyRH90v9x8E3aXvjhfj3Hf1z3kZ5t/qZsCw5q6IgDNKC8vr7WXgKugtLS0XuNaPFhaSmZmpjIyMjyPXS6XIiIilJSUpODg4FZcGWrT/dti/f4PPnr67r6K+EnNZzMaeoalb9/rm3SG5dszF/Ri/l6NGHev4iO61b0TAFqc2+1WXl6eRo4cKT8/v9ZeDlrYpSskdWnxYOnevbuOHz9ebdvx48cVHBwsp9MpHx8f+fj41Dime/futc7r7+8vf3//y7b7+fnxF9xSgf5BqioL1/CoBN0U3rH2gQl1z+V2u7VmzRqlpKQ06f3eUXRW89+/oED/IP7eAJbh83n7UN/3uMX/HZbExETl5+dX25aXl6fExERJUocOHZSQkFBtTFVVlfLz8z1jAABA+9bgYDl//ry2bdumbdu2Sfrh25a3bdumw4d/OI2fmZmpcePGecZPmDBB+/fv15QpU1RYWKiXX35Zf/rTn/T00097xmRkZOi1117Tm2++qV27dmnixIkqKSnxfNcQAABo3xp8SWjz5s0aPny45/Gl+0hSU1P1xhtv6OjRo554kaQ+ffro/fff19NPP62FCxeqZ8+eev311z3f0ixJY8eO1YkTJzRjxgwdO3ZMcXFxWrt27WU34gIAgPapwcFy1113yRhT6/M1/Su2d911l7Zu3XrFeSdNmqRJkyY1dDkAAKAd4GcJAQAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwA
AMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOv5tvYC0H5ccFdKknYUnW3yXCUXyrX5hNT90BkFOv0bPc/e4vNNXgsAoOURLLhq9v1vHExdtb2ZZvRVzt5NzTJToD+HAgDYjM/SuGqSBnSXJEV3C5LTz6dJc+0+ela/fnu75j94s27s0bFJcwX6+6pP18AmzQEAaFkEC66azoEd9LNBvZplroqKCklSdEigbgpvWrAAAOzHTbcAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6zUqWBYvXqzIyEg5HA7dfvvt2rhxY61j3W635syZo+joaDkcDsXGxmrt2rXVxpw7d06TJ09W79695XQ6NWTIEG3atKkxSwMAAG1Qg4NlxYoVysjI0MyZM7VlyxbFxsYqOTlZxcXFNY6fPn26XnnlFS1atEg7d+7UhAkTNGbMGG3dutUz5vHHH1deXp5ycnK0fft2JSUlacSIESoqKmr8ngEAgDajwcGSnZ2t9PR0paWlqX///lq6dKkCAgK0bNmyGsfn5ORo2rRpSklJUVRUlCZOnKiUlBTNnz9fknThwgWtXLlSv/3tbzVs2DD17dtXs2bNUt++fbVkyZKm7R0AAGgTfBsy+OLFiyooKFBmZqZnm7e3t0aMGKH169fX+Jry8nI5HI5q25xOp9atWydJqqioUGVl5RXH1DZveXm557HL5ZL0wyUot9vdkN3CNaiiosLzO+830LZcOqY5ttuH+r7PDQqWkydPqrKyUqGhodW2h4aGqrCwsMbXJCcnKzs7W8OGDVN0dLTy8/O1atUqVVZWSpKuu+46JSYmau7cuerXr59CQ0P1X//1X1q/fr369u1b61qysrI0e/bsy7bn5uYqICCgIbuFa9C35yXJV1999ZWKdrT2agC0hLy8vNZeAq6C0tLSeo1rULA0xsKFC5Wenq6YmBh5eXkpOjpaaWlp1S4h5eTk6LHHHlN4eLh8fHwUHx+vRx55RAUFBbXOm5mZqYyMDM9jl8uliIgIJSUlKTg4uEX3Ca3v68Onpe2bNXjwYMX26tzaywHQjNxut/Ly8jRy5Ej5+fm19nLQwi5dIalLg4Kla9eu8vHx0fHjx6ttP378uLp3717ja0JCQrR69WqVlZXp1KlTCgsL09SpUxUVFeUZEx0drU8//VQlJSVyuVzq0aOHxo4dW23Mj/n7+8vf3/+y7X5+fvwFbwd8fX09v/N+A20Tn8/bh/q+xw266bZDhw5KSEhQfn6+Z1tVVZXy8/OVmJh4xdc6HA6Fh4eroqJCK1eu1OjRoy8bExgYqB49eujMmTP68MMPaxwDAADanwZfEsrIyFBqaqoGDhyoQYMGacGCBSopKVFaWpokady4cQoPD1dWVpYkacOGDSoqKlJcXJyKioo0a9YsVVVVacqUKZ45P/zwQxljdOONN2rv3r165plnFBMT45kTAAC0bw0OlrFjx+rEiROaMWOGjh07pri4OK1du9ZzI+7hw4fl7f1/J27Kyso0ffp07d+/X0FBQUpJSVFOTo46derkGXP27FllZmbqu+++U+fOnfXAAw9o3rx5nAoEAACSJC9jjGntRTQHl8uljh076uzZs9x02w5sO3RK9y/5
SqsnDlZc7y6tvRwAzcjtdmvNmjVKSUnhf1zbgfp+/eZnCQEAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAer6tvQDgb5WWlqqwsLDOcbuPfq/yY3u1a4dTVac6XXFsTEyMAgICmmmFAIDWQLDAKoWFhUpISKj3+J+/WfeYgoICxcfHN2FVAIDWRrDAKjExMSooKKhz3PkL5Xr/4/W6d3iigpz+dc4JALi2ESywSkBAQL3Ohrjdbp05WazEQQPl5+d3FVYGAGhN3HQLAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6zUqWBYvXqzIyEg5HA7dfvvt2rhxY61j3W635syZo+joaDkcDsXGxmrt2rXVxlRWVurZZ59Vnz595HQ6FR0drblz58oY05jlAQCANqbBwbJixQplZGRo5syZ2rJli2JjY5WcnKzi4uIax0+fPl2vvPKKFi1apJ07d2rChAkaM2aMtm7d6hnz/PPPa8mSJXrppZe0a9cuPf/88/rtb3+rRYsWNX7PAABAm9HgYMnOzlZ6errS0tLUv39/LV26VAEBAVq2bFmN43NycjRt2jSlpKQoKipKEydOVEpKiubPn+8Z8+WXX2r06NG69957FRkZqQcffFBJSUlXPHMDAADaD9+GDL548aIKCgqUmZnp2ebt7a0RI0Zo/fr1Nb6mvLxcDoej2jan06l169Z5Hg8ZMkSvvvqq9uzZoxtuuEFff/211q1bp+zs7FrXUl5ervLycs9jl8sl6YdLUG63uyG7hWvQpfeY9xpoezi+25f6vs8NCpaTJ0+qsrJSoaGh1baHhoaqsLCwxtckJycrOztbw4YNU3R0tPLz87Vq1SpVVlZ6xkydOlUul0sxMTHy8fFRZWWl5s2bp0cffbTWtWRlZWn27NmXbc/NzVVAQEBDdgvXsLy8vNZeAoAWwvHdPpSWltZrXIOCpTEWLlyo9PR0xcTEyMvLS9HR0UpLS6t2CelPf/qT/vM//1NvvfWWBgwYoG3btmny5MkKCwtTampqjfNmZmYqIyPD89jlcikiIkJJSUkKDg5u6d1CK3O73crLy9PIkSPl5+fX2ssB0Iw4vtuXS1dI6tKgYOnatat8fHx0/PjxatuPHz+u7t271/iakJAQrV69WmVlZTp16pTCwsI0depURUVFecY888wzmjp1qn72s59Jkm6++WYdOnRIWVlZtQaLv7+//P39L9vu5+fHX/B2hPcbaLs4vtuH+r7HDbrptkOHDkpISFB+fr5nW1VVlfLz85WYmHjF1zocDoWHh6uiokIrV67U6NGjPc+VlpbK27v6Unx8fFRVVdWQ5QEAgDaqwZeEMjIylJqaqoEDB2rQoEFasGCBSkpKlJaWJkkaN26cwsPDlZWVJUnasGGDioqKFBcXp6KiIs2aNUtVVVWaMmWKZ8777rtP8+bNU69evTRgwABt3bpV2dnZeuyxx5ppNwEAwLWswcEyduxYnThxQjNmzNCxY8cUFxentWvXem7EPXz4cLWzJWVlZZo+fbr279+voKAgpaSkKCcnR506dfKMWbRokZ599ln98pe/VHFxscLCwvTEE09oxowZ
Td9DAABwzfMybeSfk3W5XOrYsaPOnj3LTbftgNvt1po1a5SSksI1bqCN4fhuX+r79ZufJQQAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADr+bb2AgAA7UNpaakKCwvrHHf+Qrm+3L5PP+m6WUFO/yuOjYmJUUBAQHMtERYjWAAAV0VhYaESEhLqPf639RhTUFCg+Pj4xi8K1wyCBQBwVcTExKigoKDOcbuPfq+MP29X9kM368YeneqcE+0DwQIAuCoCAgLqdTbE+9Ap+X9+Qf1uilVc7y5XYWW4FnDTLQAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrNSpYFi9erMjISDkcDt1+++3auHFjrWPdbrfmzJmj6OhoORwOxcbGau3atdXGREZGysvL67JfTz75ZGOWBwAA2pgGB8uKFSuUkZGhmTNnasuWLYqNjVVycrKKi4trHD99+nS98sorWrRokXbu3KkJEyZozJgx2rp1q2fMpk2bdPToUc+vvLw8SdJDDz3UyN0CAABtSYODJTs7W+np6UpLS1P//v21dOlSBQQEaNmyZTWOz8nJ0bRp05SSkqKoqChNnDhRKSkpmj9/vmdMSEiIunfv7vn13nvvKTo6WnfeeWfj9wwAALQZvg0ZfPHiRRUUFCgzM9OzzdvbWyNGjND69etrfE15ebkcDke1bU6nU+vWrav1Y/zHf/yHMjIy5OXlVetaysvLVV5e7nnscrkk/XAJyu1213ufcG269B7zXgNtT0VFhed3jvG2r77vcYOC5eTJk6qsrFRoaGi17aGhoSosLKzxNcnJycrOztawYcMUHR2t/Px8rVq1SpWVlTWOX716tb7//nuNHz/+imvJysrS7NmzL9uem5urgICA+u0QrnmXLh8CaDu+PS9Jvvrqq69UtKO1V4OWVlpaWq9xDQqWxli4cKHS09MVExMjLy8vRUdHKy0trdZLSL///e81atQohYWFXXHezMxMZWRkeB67XC5FREQoKSlJwcHBzboPsI/b7VZeXp5GjhwpPz+/1l4OgGb09eHT0vbNGjx4sGJ7dW7t5aCFXbpCUpcGBUvXrl3l4+Oj48ePV9t+/Phxde/evcbXhISEaPXq1SorK9OpU6cUFhamqVOnKioq6rKxhw4d0kcffaRVq1bVuRZ/f3/5+/tftt3Pz48vYO0I7zfQ9vj6+np+5/hu++r7HjfoptsOHTooISFB+fn5nm1VVVXKz89XYmLiFV/rcDgUHh6uiooKrVy5UqNHj75szPLly9WtWzfde++9DVkWAABo4xp8SSgjI0OpqakaOHCgBg0apAULFqikpERpaWmSpHHjxik8PFxZWVmSpA0bNqioqEhxcXEqKirSrFmzVFVVpSlTplSbt6qqSsuXL1dqaqqnrgEAAKRGBMvYsWN14sQJzZgxQ8eOHVNcXJzWrl3ruRH38OHD8vb+vxM3ZWVlmj59uvbv36+goCClpKQoJydHnTp1qjbvRx99pMOHD+uxxx5r2h4BAIA2p1GnMiZNmqRJkybV+Nwnn3xS7fGdd96pnTt31jlnUlKSjDGNWQ4AAGjj+FlCAADA
etwsAgBoFgdOlqikvKLJ8+w7UeL5van3NAb6+6pP18Amrwmtj2ABADTZgZMlGv7CJ80656/f3t4s83z8m7uIljaAYAEANNmlMysLxsapb7egps11oVzvfbJeP70rUYHOy/+9rfraW3xek1dsa5azPmh9BAsAoNn07Rakm8I7NmkOt9utYyFSfO+f8A/HwYObbgEAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYz7e1FwAAuPaVV5bJ21GkA67d8nYENWmuiooKHak4ol2nd8nXt/Ffpg64zsvbUaTyyjJJHZu0JrQ+ggUA0GRHSg4psM8iTdvYfHO+vPblJs8R2Ec6UhKnBIU2w4rQmggWAECThQX2VsmBX2nh2DhFd2v6GZYv1n2hof9vaJPOsOwrPq9/WbFNYcN7N2k9sAPBAgBoMn8fh6rKwtUn+Eb179K0yy9ut1sHfA+oX+d+8vPza/Q8VWVnVVV2Qv4+jiatB3bgplsAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFiPYAEAANYjWAAAgPUIFgAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADW823tBQAArn0X3JWSpB1FZ5s8V8mFcm0+IXU/dEaBTv9Gz7O3+HyT1wJ7ECwAgCbb979xMHXV9maa0Vc5ezc1y0yB/nypawt4FwEATZY0oLskKbpbkJx+Pk2aa/fRs/r129s1/8GbdWOPjk2aK9DfV326BjZpDtiBYAEANFnnwA762aBezTJXRUWFJCk6JFA3hTctWNB2cNMtAACwXqOCZfHixYqMjJTD4dDtt9+ujRs31jrW7XZrzpw5io6OlsPhUGxsrNauXXvZuKKiIv3jP/6junTpIqfTqZtvvlmbN29uzPIAAEAb0+BgWbFihTIyMjRz5kxt2bJFsbGxSk5OVnFxcY3jp0+frldeeUWLFi3Szp07NWHCBI0ZM0Zbt271jDlz5oyGDh0qPz8/ffDBB9q5c6fmz5+vn/zkJ43fMwAA0GY0OFiys7OVnp6utLQ09e/fX0uXLlVAQICWLVtW4/icnBxNmzZNKSkpioqK0sSJE5WSkqL58+d7xjz//POKiIjQ8uXLNWjQIPXp00dJSUmKjo5u/J4BAIA2o0E33V68eFEFBQXKzMz0bPP29taIESO0fv36Gl9TXl4uh8NRbZvT6dS6des8j999910lJyfroYce0qeffqrw8HD98pe/VHp6eq1rKS8vV3l5ueexy+WS9MMlKLfb3ZDdwjXo0nvMew20PZduuq2oqOAYbwfq+x43KFhOnjypyspKhYaGVtseGhqqwsLCGl+TnJys7OxsDRs2TNHR0crPz9eqVatUWVnpGbN//34tWbJEGRkZmjZtmjZt2qSnnnpKHTp0UGpqao3zZmVlafbs2Zdtz83NVUBAQEN2C9ewvLy81l4CgGb27XlJ8tVXX32loh2tvRq0tNLS0nqNa/Fva164cKHS09MVExMjLy8vRUdHKy0trdolpKqqKg0cOFDPPfecJOnWW2/Vjh07tHTp0lqDJTMzUxkZGZ7HLpdLERERSkpKUnBwcMvuFFqd2+1WXl6eRo4cKT8/v9ZeDoBm9PXh09L2zRo8eLBie3Vu7eWghV26QlKXBgVL165d5ePjo+PHj1fbfvz4cXXv3r3G14SEhGj16tUqKyvTqVOnFBYWpqlTpyoqKsozpkePHurfv3+11/Xr108r
V66sdS3+/v7y97/8n2z28/PjC1g7wvsNtD2+vr6e3zm+2776vscNuum2Q4cOSkhIUH5+vmdbVVWV8vPzlZiYeMXXOhwOhYeHq6KiQitXrtTo0aM9zw0dOlS7d++uNn7Pnj3q3bt3Q5YHAADaqAZfEsrIyFBqaqoGDhyoQYMGacGCBSopKVFaWpokady4cQoPD1dWVpYkacOGDSoqKlJcXJyKioo0a9YsVVVVacqUKZ45n376aQ0ZMkTPPfecHn74YW3cuFGvvvqqXn311WbaTQAAcC1rcLCMHTtWJ06c0IwZM3Ts2DHFxcVp7dq1nhtxDx8+LG/v/ztxU1ZWpunTp2v//v0KCgpSSkqKcnJy1KlTJ8+Y2267Te+8844yMzM1Z84c9enTRwsWLNCjjz7a9D0EAADXPC9jjGntRTQHl8uljh076uzZs9x02w643W6tWbNGKSkpXOMG2phth07p/iVfafXEwYrr3aW1l4MWVt+v3/wsIQAAYD2CBQAAWI9gAQAA1iNYAACA9QgWAABgPYIFAABYj2ABAADWI1gAAID1CBYAAGA9ggUAAFivwT9LCACAxigtLVVhYWGd43Yf/V7lx/Zq1w6nqk51uuLYmJgYBQQENNMKYTOCBQBwVRQWFiohIaHe43/+Zt1jCgoKFB8f34RV4VpBsAAAroqYmBgVFBTUOe78hXK9//F63Ts8UUFO/zrnRPtAsAAAroqAgIB6nQ1xu906c7JYiYMG8tPY4cFNtwAAwHoECwAAsB7BAgAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALAewQIAAKxHsAAAAOsRLAAAwHpt5qc1G2MkSS6Xq5VXgqvB7XartLRULpeLn+YKtDEc3+3Lpa/bl76O16bNBMu5c+ckSREREa28EgAA0FDnzp1Tx44da33ey9SVNNeIqqoqHTlyRNddd528vLxaezloYS6XSxEREfr2228VHBzc2ssB0Iw4vtsXY4zOnTunsLAweXvXfqdKmznD4u3trZ49e7b2MnCVBQcH8wkNaKM4vtuPK51ZuYSbbgEAgPUIFgAAYD2CBdckf39/zZw5U/7+/q29FADNjOMbNWkzN90CAIC2izMsAADAegQLAACwHsECAACsR7DgqouMjNSCBQtaexkAgGsIwYJ6ueuuuzR58uRmmWvTpk36xS9+0SxzAWhezXmsS9L48eN1//33N9t8aL8IFjQLY4wqKirqNTYkJEQBAQEtvKKW4Xa7W3sJAJrRxYsXW3sJqCeCBXUaP368Pv30Uy1cuFBeXl7y8vLSG2+8IS8vL33wwQdKSEiQv7+/1q1bp3379mn06NEKDQ1VUFCQbrvtNn300UfV5vvxJSEvLy+9/vrrGjNmjAICAnT99dfr3Xffrdfazpw5o0cffVQhISFyOp26/vrrtXz5cs/z3333nR555BF17txZgYGBGjhwoDZs2OB5fsmSJYqOjlaHDh104403Kicnp9r8Xl5eWrJkif7+7/9egYGBmjdvniTpL3/5i+Lj4+VwOBQVFaXZs2fXO9gAW9V0rB88eFA7duzQqFGjFBQUpNDQUP3TP/2TTp486Xnd22+/rZtvvllOp1NdunTRiBEjVFJSolmzZunNN9/UX/7yF898n3zyyRXXcPHiRU2aNEk9evSQw+FQ7969lZWV5Xn++++/1xNPPKHQ0FA5HA7ddNNNeu+99zzPr1y5UgMGDJC/v78iIyM1f/78avNHRkZq7ty5GjdunIKDgz1ne9etW6c77rhDTqdTEREReuqpp1RSUtIMf6poNgaow/fff28SExNNenq6OXr0qDl69Kj56KOPjCRzyy23mNzcXLN3715z6tQps23bNrN06VKzfft2s2fPHjN9+nTjcDjMoUOHPPP17t3bvPjii57HkkzPnj3NW2+9Zb755hvz1FNPmaCgIHPq1Kk61/bkk0+auLg4s2nTJnPgwAGTl5dn3n33XWOMMefO
nTNRUVHmjjvuMJ9//rn55ptvzIoVK8yXX35pjDFm1apVxs/PzyxevNjs3r3bzJ8/3/j4+Jj//u//rra2bt26mWXLlpl9+/aZQ4cOmc8++8wEBwebN954w+zbt8/k5uaayMhIM2vWrGb6EwdaR03H+smTJ01ISIjJzMw0u3btMlu2bDEjR440w4cPN8YYc+TIEePr62uys7PNgQMHzP/8z/+YxYsXm3Pnzplz586Zhx9+2Nxzzz2e+crLy6+4ht/97ncmIiLCfPbZZ+bgwYPm888/N2+99ZYxxpjKykozePBgM2DAAJObm2v27dtn/vrXv5o1a9YYY4zZvHmz8fb2NnPmzDG7d+82y5cvN06n0yxfvtwzf+/evU1wcLB54YUXzN69ez2/AgMDzYsvvmj27NljvvjiC3Prrbea8ePHt8wfNBqFYEG93HnnneZf/uVfPI8//vhjI8msXr26ztcOGDDALFq0yPO4pmCZPn265/H58+eNJPPBBx/UOfd9991n0tLSanzulVdeMdddd12t4TNkyBCTnp5ebdtDDz1kUlJSqq1t8uTJ1cbcfffd5rnnnqu2LScnx/To0aPO9QK2+/GxPnfuXJOUlFRtzLfffmskmd27d5uCggIjyRw8eLDG+VJTU83o0aPr/fF/9atfmb/7u78zVVVVlz334YcfGm9vb7N79+4aX/vzn//cjBw5stq2Z555xvTv39/zuHfv3ub++++vNuaf//mfzS9+8Ytq2z7//HPj7e1tLly4UO+1o2VxSQhNMnDgwGqPz58/r9/85jfq16+fOnXqpKCgIO3atUuHDx++4jy33HKL578DAwMVHBys4uLiOj/+xIkT9cc//lFxcXGaMmWKvvzyS89z27Zt06233qrOnTvX+Npdu3Zp6NCh1bYNHTpUu3btuuI+fv3115ozZ46CgoI8v9LT03X06FGVlpbWuWbgWvL111/r448/rvb3PSYmRpK0b98+xcbG6u6779bNN9+shx56SK+99prOnDnT6I83fvx4bdu2TTfeeKOeeuop5ebmep7btm2bevbsqRtuuKHG19Z2TH/zzTeqrKz0bKvpmH7jjTeq7WNycrKqqqp04MCBRu8Lmpdvay8A17bAwMBqj3/zm98oLy9PL7zwgvr27Sun06kHH3ywzhvb/Pz8qj328vJSVVVVnR9/1KhROnTokNasWaO8vDzdfffdevLJJ/XCCy/I6XQ2fIdq8ON9PH/+vGbPnq1/+Id/uGysw+Folo8J2OL8+fO677779Pzzz1/2XI8ePeTj46O8vDx9+eWXys3N1aJFi/Rv//Zv2rBhg/r06dPgjxcfH68DBw7ogw8+0EcffaSHH35YI0aM0Ntvv92ix/QTTzyhp5566rKxvXr1apaPiaYjWFAvHTp0qPZ/KLX54osvNH78eI0ZM0bSD58IDh482KJrCwkJUWpqqlJTU3XHHXfomWee0QsvvKBbbrlFr7/+uk6fPl3jWZZ+/frpiy++UGpqarX19+/f/4ofLz4+Xrt371bfvn2bfV+A1vbjYz0+Pl4rV65UZGSkfH1r/pLh5eWloUOHaujQoZoxY4Z69+6td955RxkZGfX+3PG3goODNXbsWI0dO1YPPvig7rnnHp0+fVq33HKLvvvuO+3Zs6fGsyyXjum/9cUXX+iGG26Qj49PrR8vPj5eO3fu5Ji2HMGCeomMjNSGDRt08OBBBQUF1Xr24/rrr9eqVat03333ycvLS88++2y9zpQ01owZM5SQkKABAwaovLxc7733nvr16ydJeuSRR/Tcc8/p/vvvV1ZWlnr06KGtW7cqLCxMiYmJeuaZZ/Twww/r1ltv1YgRI/TXv/5Vq1atuuy7mmr6mD/96U/Vq1cvPfjgg/L29tbXX3+tHTt26N///d9bbF+Bq+HHx/qTTz6p1157TY888oimTJmizp07a+/evfrjH/+o119/XZs3b1Z+fr6SkpLUrVs3bdiwQSdOnPAch5GRkfrwww+1
e/dudenSRR07drzsjOrfys7OVo8ePXTrrbfK29tbf/7zn9W9e3d16tRJd955p4YNG6YHHnhA2dnZ6tu3rwoLC+Xl5aV77rlHv/71r3Xbbbdp7ty5Gjt2rNavX6+XXnpJL7/88hX3+V//9V81ePBgTZo0SY8//rgCAwO1c+dO5eXl6aWXXmrWP180QWvfRINrw+7du83gwYON0+k0kszy5cuNJHPmzJlq4w4cOGCGDx9unE6niYiIMC+99NJlN/HVdNPtO++8U22ejh07VruzvzZz5841/fr1M06n03Tu3NmMHj3a7N+/3/P8wYMHzQMPPGCCg4NNQECAGThwoNmwYYPn+ZdfftlERUUZPz8/c8MNN5g//OEP1eavaW3GGLN27VozZMgQ43Q6TXBwsBk0aJB59dVX61wvYLsfH+sHDhwwe/bsMWPGjDGdOnUyTqfTxMTEmMmTJ5uqqiqzc+dOk5ycbEJCQoy/v7+54YYbqt1kX1xcbEaOHGmCgoKMJPPxxx9f8eO/+uqrJi4uzgQGBprg4GBz9913my1btnieP3XqlElLSzNdunQxDofD3HTTTea9997zPP/222+b/v37Gz8/P9OrVy/zu9/9rtr8P/78c8nGjRs96wwMDDS33HKLmTdvXuP+ENEivIwxpjWDCQAAoC58lxAAALAewQKrTZgwodq3Gv7trwkTJrT28gA00HPPPVfrMT1q1KjWXh4sxiUhWK24uFgul6vG54KDg9WtW7ervCIATXH69GmdPn26xuecTqfCw8Ov8opwrSBYAACA9bgkBAAArEewAAAA6xEsAADAegQLAACwHsECAACsR7AAAADrESwAAMB6BAsAALDe/wfDm7JGC2K1IQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import cross_validate\n",
    "\n",
    "X = X_train\n",
    "y = y_train\n",
    "# Pipeline: feature scaling followed by a logistic regression model\n",
    "pipe = make_pipeline(\n",
    "    MinMaxScaler(),  # scale each feature to the range [0, 1]\n",
    "    LogisticRegression(solver='saga',\n",
    "                       multi_class='auto',\n",
    "                       random_state=42,\n",
    "                       max_iter=5000))\n",
    "# Pipeline parameters are addressed as '<step name>__<parameter>'\n",
    "param_grid = {\n",
    "    'logisticregression__C': [0.1, 1.0, 10],  # candidate values for C\n",
    "    'logisticregression__penalty': ['l2', 'l1']  # candidate penalty types\n",
    "}\n",
    "# Grid search over the whole pipeline\n",
    "grid = GridSearchCV(pipe, param_grid=param_grid, cv=3,\n",
    "                    n_jobs=-1)\n",
    "# Nested cross-validation: the outer 3-fold loop evaluates the grid search\n",
    "# itself, and the per-fold results are collected in a DataFrame\n",
    "scores = pd.DataFrame(\n",
    "    cross_validate(grid, X, y, cv=3, n_jobs=-1,\n",
    "                   return_train_score=True))\n",
    "# Box plot of the training and test scores across the outer folds\n",
    "scores[['train_score', 'test_score']].boxplot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the Pipeline is 0.96\n"
     ]
    }
   ],
   "source": [
    "pipe.fit(X_train, y_train)\n",
    "accuracy = pipe.score(X_test, y_test)\n",
    "print('Accuracy score of the {} is {:.2f}'.format(pipe.__class__.__name__,\n",
    "                                                  accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use `get_params()` to inspect all parameters of the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'memory': None,\n",
       " 'steps': [('minmaxscaler', MinMaxScaler()),\n",
       "  ('logisticregression',\n",
       "   LogisticRegression(max_iter=5000, random_state=42, solver='saga'))],\n",
       " 'verbose': False,\n",
       " 'minmaxscaler': MinMaxScaler(),\n",
       " 'logisticregression': LogisticRegression(max_iter=5000, random_state=42, solver='saga'),\n",
       " 'minmaxscaler__clip': False,\n",
       " 'minmaxscaler__copy': True,\n",
       " 'minmaxscaler__feature_range': (0, 1),\n",
       " 'logisticregression__C': 1.0,\n",
       " 'logisticregression__class_weight': None,\n",
       " 'logisticregression__dual': False,\n",
       " 'logisticregression__fit_intercept': True,\n",
       " 'logisticregression__intercept_scaling': 1,\n",
       " 'logisticregression__l1_ratio': None,\n",
       " 'logisticregression__max_iter': 5000,\n",
       " 'logisticregression__multi_class': 'auto',\n",
       " 'logisticregression__n_jobs': None,\n",
       " 'logisticregression__penalty': 'l2',\n",
       " 'logisticregression__random_state': 42,\n",
       " 'logisticregression__solver': 'saga',\n",
       " 'logisticregression__tol': 0.0001,\n",
       " 'logisticregression__verbose': 0,\n",
       " 'logisticregression__warm_start': False}"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe.get_params()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition, once fitted, the grid search object can be called like any other classifier to make predictions."
   ]
  },
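  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is a minimal sketch of that idea, not part of the original exercise. It assumes the `digits` dataset as stand-in data and uses hypothetical names (`demo_pipe`, `demo_grid`, `X_tr`, `X_te`) so that it does not overwrite the variables defined above. After `fit`, `GridSearchCV` refits the best pipeline on the whole training set, and `predict` and `score` delegate to that refitted estimator.\n",
    "\n",
    "```python\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import GridSearchCV, train_test_split\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "# Assumption: the digits dataset stands in for the data used in this section.\n",
    "X_demo, y_demo = load_digits(return_X_y=True)\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(\n",
    "    X_demo, y_demo, stratify=y_demo, random_state=42)\n",
    "\n",
    "demo_pipe = make_pipeline(MinMaxScaler(), LogisticRegression(max_iter=1000))\n",
    "demo_grid = GridSearchCV(\n",
    "    demo_pipe,\n",
    "    param_grid={'logisticregression__C': [0.1, 1.0, 10]},\n",
    "    cv=3)\n",
    "\n",
    "# fit() runs the search, then refits the best pipeline on all of X_tr.\n",
    "demo_grid.fit(X_tr, y_tr)\n",
    "\n",
    "# predict() and score() are forwarded to the refitted best estimator.\n",
    "y_pred = demo_grid.predict(X_te)\n",
    "accuracy = demo_grid.score(X_te, y_te)\n",
    "print(demo_grid.best_params_)\n",
    "print('Accuracy score of the {} is {:.2f}'.format(\n",
    "    demo_grid.__class__.__name__, accuracy))\n",
    "```\n",
    "\n",
    "Because the grid search behaves like a single estimator, it can itself be passed to `cross_validate`, which is what the nested cross-validation earlier in this section relies on."
   ]
  },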
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exercise: Heterogeneous data, when you work with data other than numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pclass</th>\n",
       "      <th>survived</th>\n",
       "      <th>name</th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>ticket</th>\n",
       "      <th>fare</th>\n",
       "      <th>cabin</th>\n",
       "      <th>embarked</th>\n",
       "      <th>boat</th>\n",
       "      <th>body</th>\n",
       "      <th>home.dest</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Allen, Miss. Elisabeth Walton</td>\n",
       "      <td>female</td>\n",
       "      <td>29.0000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>24160</td>\n",
       "      <td>211.3375</td>\n",
       "      <td>B5</td>\n",
       "      <td>S</td>\n",
       "      <td>2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>St Louis, MO</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Allison, Master. Hudson Trevor</td>\n",
       "      <td>male</td>\n",
       "      <td>0.9167</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>11</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Montreal, PQ / Chesterville, ON</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>Allison, Miss. Helen Loraine</td>\n",
       "      <td>female</td>\n",
       "      <td>2.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Montreal, PQ / Chesterville, ON</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>Allison, Mr. Hudson Joshua Creighton</td>\n",
       "      <td>male</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>NaN</td>\n",
       "      <td>135.0</td>\n",
       "      <td>Montreal, PQ / Chesterville, ON</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>Allison, Mrs. Hudson J C (Bessie Waldo Daniels)</td>\n",
       "      <td>female</td>\n",
       "      <td>25.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>Montreal, PQ / Chesterville, ON</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   pclass  survived                                             name     sex  \\\n",
       "0       1         1                    Allen, Miss. Elisabeth Walton  female   \n",
       "1       1         1                   Allison, Master. Hudson Trevor    male   \n",
       "2       1         0                     Allison, Miss. Helen Loraine  female   \n",
       "3       1         0             Allison, Mr. Hudson Joshua Creighton    male   \n",
       "4       1         0  Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female   \n",
       "\n",
       "       age  sibsp  parch  ticket      fare    cabin embarked boat   body  \\\n",
       "0  29.0000      0      0   24160  211.3375       B5        S    2    NaN   \n",
       "1   0.9167      1      2  113781  151.5500  C22 C26        S   11    NaN   \n",
       "2   2.0000      1      2  113781  151.5500  C22 C26        S  NaN    NaN   \n",
       "3  30.0000      1      2  113781  151.5500  C22 C26        S  NaN  135.0   \n",
       "4  25.0000      1      2  113781  151.5500  C22 C26        S  NaN    NaN   \n",
       "\n",
       "                         home.dest  \n",
       "0                     St Louis, MO  \n",
       "1  Montreal, PQ / Chesterville, ON  \n",
       "2  Montreal, PQ / Chesterville, ON  \n",
       "3  Montreal, PQ / Chesterville, ON  \n",
       "4  Montreal, PQ / Chesterville, ON  "
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv('data/titanic_openml.csv', na_values='?')\n",
    "# Read 'titanic_openml.csv' and treat question marks (?) as missing values\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
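    "The Titanic dataset contains categorical, text, and numeric features. We will use it to predict whether a passenger survived the sinking of the Titanic.\n",
    "\n",
    "Let's split the data into training and test sets, using the `survived` column as the target."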
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = data['survived']\n",
    "X = data.drop(columns='survived')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we can try a `LogisticRegression` classifier and see how well it performs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "could not convert string to float: 'Rekic, Mr. Tido'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[36], line 2\u001b[0m\n\u001b[0;32m      1\u001b[0m clf \u001b[38;5;241m=\u001b[39m LogisticRegression()\n\u001b[1;32m----> 2\u001b[0m \u001b[43mclf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;66;03m#这里肯定会报错。\u001b[39;00m\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:1196\u001b[0m, in \u001b[0;36mLogisticRegression.fit\u001b[1;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[0;32m   1193\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   1194\u001b[0m     _dtype \u001b[38;5;241m=\u001b[39m [np\u001b[38;5;241m.\u001b[39mfloat64, np\u001b[38;5;241m.\u001b[39mfloat32]\n\u001b[1;32m-> 1196\u001b[0m X, y \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1197\u001b[0m \u001b[43m    \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1198\u001b[0m \u001b[43m    \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1199\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcsr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1200\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1201\u001b[0m \u001b[43m    \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mC\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1202\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccept_large_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msolver\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mliblinear\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msag\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msaga\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1203\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1204\u001b[0m check_classification_targets(y)\n\u001b[0;32m   1205\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclasses_ \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39munique(y)\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\sklearn\\base.py:554\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[1;34m(self, X, y, reset, validate_separately, **check_params)\u001b[0m\n\u001b[0;32m    552\u001b[0m         y \u001b[38;5;241m=\u001b[39m check_array(y, input_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mcheck_y_params)\n\u001b[0;32m    553\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 554\u001b[0m         X, y \u001b[38;5;241m=\u001b[39m check_X_y(X, y, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mcheck_params)\n\u001b[0;32m    555\u001b[0m     out \u001b[38;5;241m=\u001b[39m X, y\n\u001b[0;32m    557\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m check_params\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mensure_2d\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m):\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py:1104\u001b[0m, in \u001b[0;36mcheck_X_y\u001b[1;34m(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)\u001b[0m\n\u001b[0;32m   1099\u001b[0m         estimator_name \u001b[38;5;241m=\u001b[39m _check_estimator_name(estimator)\n\u001b[0;32m   1100\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m   1101\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mestimator_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m requires y to be passed, but the target y is None\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   1102\u001b[0m     )\n\u001b[1;32m-> 1104\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_array\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1105\u001b[0m \u001b[43m    \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1106\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1107\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccept_large_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maccept_large_sparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1108\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1109\u001b[0m \u001b[43m    \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1110\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1111\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mforce_all_finite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_all_finite\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1112\u001b[0m \u001b[43m    \u001b[49m\u001b[43mensure_2d\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_2d\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1113\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_nd\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_nd\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1114\u001b[0m \u001b[43m    \u001b[49m\u001b[43mensure_min_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_min_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1115\u001b[0m \u001b[43m    \u001b[49m\u001b[43mensure_min_features\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensure_min_features\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1116\u001b[0m \u001b[43m    \u001b[49m\u001b[43mestimator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1117\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mX\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1118\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1120\u001b[0m y \u001b[38;5;241m=\u001b[39m _check_y(y, multi_output\u001b[38;5;241m=\u001b[39mmulti_output, y_numeric\u001b[38;5;241m=\u001b[39my_numeric, estimator\u001b[38;5;241m=\u001b[39mestimator)\n\u001b[0;32m   1122\u001b[0m check_consistent_length(X, y)\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py:877\u001b[0m, in \u001b[0;36mcheck_array\u001b[1;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\u001b[0m\n\u001b[0;32m    875\u001b[0m         array \u001b[38;5;241m=\u001b[39m xp\u001b[38;5;241m.\u001b[39mastype(array, dtype, copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m    876\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 877\u001b[0m         array \u001b[38;5;241m=\u001b[39m \u001b[43m_asarray_with_order\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mxp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mxp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    878\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ComplexWarning \u001b[38;5;28;01mas\u001b[39;00m complex_warning:\n\u001b[0;32m    879\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m    880\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mComplex data not supported\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(array)\n\u001b[0;32m    881\u001b[0m     ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mcomplex_warning\u001b[39;00m\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\sklearn\\utils\\_array_api.py:185\u001b[0m, in \u001b[0;36m_asarray_with_order\u001b[1;34m(array, dtype, order, copy, xp)\u001b[0m\n\u001b[0;32m    182\u001b[0m     xp, _ \u001b[38;5;241m=\u001b[39m get_namespace(array)\n\u001b[0;32m    183\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m xp\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnumpy\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnumpy.array_api\u001b[39m\u001b[38;5;124m\"\u001b[39m}:\n\u001b[0;32m    184\u001b[0m     \u001b[38;5;66;03m# Use NumPy API to support order\u001b[39;00m\n\u001b[1;32m--> 185\u001b[0m     array \u001b[38;5;241m=\u001b[39m \u001b[43mnumpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    186\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m xp\u001b[38;5;241m.\u001b[39masarray(array, copy\u001b[38;5;241m=\u001b[39mcopy)\n\u001b[0;32m    187\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\pandas\\core\\generic.py:2070\u001b[0m, in \u001b[0;36mNDFrame.__array__\u001b[1;34m(self, dtype)\u001b[0m\n\u001b[0;32m   2069\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__array__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dtype: npt\u001b[38;5;241m.\u001b[39mDTypeLike \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m np\u001b[38;5;241m.\u001b[39mndarray:\n\u001b[1;32m-> 2070\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_values\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mValueError\u001b[0m: could not convert string to float: 'Rekic, Mr. Tido'"
     ]
    }
   ],
   "source": [
    "clf = LogisticRegression()\n",
    "clf.fit(X_train, y_train)  # expected to raise an error: X_train still contains strings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
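    "Most classifiers are designed to work with numerical data. We therefore need to convert the categorical data into numeric features. The simplest approach is to one-hot encode each categorical feature with `OneHotEncoder`. Let's take the `sex` and `embarked` columns as an example. Note that we will also run into some missing data. We will use `SimpleImputer` to replace the missing values with a constant value."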
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 1., 0., 0., 1., 0.],\n",
       "       [0., 1., 1., 0., 0., 0.],\n",
       "       [0., 1., 0., 0., 1., 0.],\n",
       "       ...,\n",
       "       [0., 1., 0., 0., 1., 0.],\n",
       "       [1., 0., 0., 0., 1., 0.],\n",
       "       [1., 0., 0., 0., 1., 0.]])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "ohe = make_pipeline(SimpleImputer(strategy='constant'), OneHotEncoder())\n",
    "# Build a pipeline with two steps:\n",
    "# 1. SimpleImputer with strategy='constant' fills missing values with a constant\n",
    "# 2. OneHotEncoder one-hot encodes the categorical variables\n",
    "X_encoded = ohe.fit_transform(X_train[['sex', 'embarked']])\n",
    "# fit_transform processes the 'sex' and 'embarked' columns of X_train:\n",
    "# missing values are imputed first, then the result is one-hot encoded\n",
    "X_encoded.toarray()\n",
    "# Convert the sparse one-hot encoded matrix to a dense array for display"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
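    "This is how categorical features can be encoded. However, we also want to standardize the numeric features. We therefore need to split the original data into two subgroups and apply different preprocessing to each: (i) one-hot encoding for the categorical data; (ii) standard scaling (standardization) for the numeric data. We also need to handle missing values in both cases. For the categorical columns, we replace missing values with the string `'missing_value'`, which is then treated as a category of its own. For the numeric data, we replace missing values with the mean of the feature concerned."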
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "col_cat = ['sex', 'embarked']\n",
    "col_num = ['age', 'sibsp', 'parch', 'fare']\n",
    "\n",
    "X_train_cat = X_train[col_cat]\n",
    "X_train_num = X_train[col_num]\n",
    "X_test_cat = X_test[col_cat]\n",
    "X_test_num = X_test[col_num]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "# StandardScaler standardizes numeric features\n",
    "scaler_cat = make_pipeline(SimpleImputer(strategy='constant'), OneHotEncoder())\n",
    "# Pipeline for the categorical columns, with two steps:\n",
    "# 1. SimpleImputer with strategy='constant' fills missing values with a constant\n",
    "# 2. OneHotEncoder one-hot encodes the categorical variables\n",
    "X_train_cat_enc = scaler_cat.fit_transform(X_train_cat)\n",
    "# Fit on the training categorical columns, then impute and one-hot encode them\n",
    "X_test_cat_enc = scaler_cat.transform(X_test_cat)\n",
    "# Apply the already fitted imputer and encoder to the test categorical columns\n",
    "scaler_num = make_pipeline(SimpleImputer(strategy='mean'), StandardScaler())\n",
    "# Pipeline for the numeric columns, with two steps:\n",
    "# 1. SimpleImputer with strategy='mean' fills missing values with the column mean\n",
    "# 2. StandardScaler standardizes the numeric variables\n",
    "X_train_num_scaled = scaler_num.fit_transform(X_train_num)\n",
    "# Fit on the training numeric columns, then impute and standardize them\n",
    "X_test_num_scaled = scaler_num.transform(X_test_num)\n",
    "# Apply the means and scaling learned on the training set to the test numeric columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import sparse\n",
    "# scipy.sparse provides sparse matrix utilities\n",
    "# Combine both feature groups into a single sparse matrix\n",
    "X_train_scaled = sparse.hstack((X_train_cat_enc,\n",
    "                                sparse.csr_matrix(X_train_num_scaled)))\n",
    "# hstack concatenates the one-hot encoded categorical features and the standardized numeric features column-wise\n",
    "# csr_matrix converts the dense numeric array to compressed sparse row format\n",
    "X_test_scaled = sparse.hstack((X_test_cat_enc,\n",
    "                               sparse.csr_matrix(X_test_num_scaled)))\n",
    "# The test set is assembled in the same way, with the columns in the same order"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
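    "Once the transformation is done, all the information is numerical and has been combined into a single matrix. Finally, we use a `LogisticRegression` classifier as the model."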
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy score of the LogisticRegression is 0.7866\n"
     ]
    }
   ],
   "source": [
    "clf = LogisticRegression(solver='lbfgs')\n",
    "# Create a LogisticRegression model that uses the 'lbfgs' solver\n",
    "clf.fit(X_train_scaled, y_train)\n",
    "# Train the model on the preprocessed training set X_train_scaled and its labels y_train\n",
    "accuracy = clf.score(X_test_scaled, y_test)\n",
    "# Compute the accuracy on the preprocessed test set X_test_scaled\n",
    "print('Accuracy score of the {} is {:.4f}'.format(clf.__class__.__name__, accuracy))\n",
    "# Print the model class name together with its accuracy"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
